{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1.5 「ファインチューニング」で精度向上を実現する方法\n",
    "\n",
    "- 本ファイルでは、学習済みのVGGモデルを使用し、ファインチューニングでアリとハチの画像を分類するモデルを学習します\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 学習目標\n",
    "\n",
    "1.\tPyTorchでGPUを使用する実装コードを書けるようになる\n",
    "2.\t最適化手法の設定において、層ごとに異なる学習率を設定したファインチューニングを実装できるようになる\n",
    "3.\t学習したネットワークを保存・ロードできるようになる\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 事前準備\n",
    "\n",
    "- 1.4節で解説したAWS EC2 のGPUインスタンスを使用します\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# パッケージのimport\n",
    "import numpy as np\n",
    "import random\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "\n",
    "from torchvision import models\n",
    "\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 乱数のシードを設定\n",
    "torch.manual_seed(1234)\n",
    "np.random.seed(1234)\n",
    "random.seed(1234)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# DatasetとDataLoaderを作成"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "./data/hymenoptera_data/train/**/*.jpg\n",
      "./data/hymenoptera_data/val/**/*.jpg\n"
     ]
    }
   ],
   "source": [
    "# 1.3節で作成したクラスを同じフォルダにあるmake_dataset_dataloader.pyに記載して使用\n",
    "from utils.dataloader_image_classification import ImageTransform, make_datapath_list, HymenopteraDataset\n",
    "\n",
    "# アリとハチの画像へのファイルパスのリストを作成する\n",
    "train_list = make_datapath_list(phase=\"train\")\n",
    "val_list = make_datapath_list(phase=\"val\")\n",
    "\n",
    "# Datasetを作成する\n",
    "size = 224\n",
    "mean = (0.485, 0.456, 0.406)\n",
    "std = (0.229, 0.224, 0.225)\n",
    "train_dataset = HymenopteraDataset(\n",
    "    file_list=train_list, transform=ImageTransform(size, mean, std), phase='train')\n",
    "\n",
    "val_dataset = HymenopteraDataset(\n",
    "    file_list=val_list, transform=ImageTransform(size, mean, std), phase='val')\n",
    "\n",
    "\n",
    "# DataLoaderを作成する\n",
    "batch_size = 32\n",
    "\n",
    "train_dataloader = torch.utils.data.DataLoader(\n",
    "    train_dataset, batch_size=batch_size, shuffle=True)\n",
    "\n",
    "val_dataloader = torch.utils.data.DataLoader(\n",
    "    val_dataset, batch_size=batch_size, shuffle=False)\n",
    "\n",
    "# 辞書オブジェクトにまとめる\n",
    "dataloaders_dict = {\"train\": train_dataloader, \"val\": val_dataloader}\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ネットワークモデルの作成"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ネットワーク設定完了：学習済みの重みをロードし、訓練モードに設定しました\n"
     ]
    }
   ],
   "source": [
    "# 学習済みのVGG-16モデルをロード\n",
    "\n",
    "# VGG-16モデルのインスタンスを生成\n",
    "use_pretrained = True  # 学習済みのパラメータを使用\n",
    "net = models.vgg16(pretrained=use_pretrained)\n",
    "\n",
    "# VGG16の最後の出力層の出力ユニットをアリとハチの2つに付け替える\n",
    "net.classifier[6] = nn.Linear(in_features=4096, out_features=2)\n",
    "\n",
    "# 訓練モードに設定\n",
    "net.train()\n",
    "\n",
    "print('ネットワーク設定完了：学習済みの重みをロードし、訓練モードに設定しました')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 損失関数を定義"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 損失関数の設定\n",
    "criterion = nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 最適化手法を設定"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "params_to_update_1に格納： features.0.weight\n",
      "params_to_update_1に格納： features.0.bias\n",
      "params_to_update_1に格納： features.2.weight\n",
      "params_to_update_1に格納： features.2.bias\n",
      "params_to_update_1に格納： features.5.weight\n",
      "params_to_update_1に格納： features.5.bias\n",
      "params_to_update_1に格納： features.7.weight\n",
      "params_to_update_1に格納： features.7.bias\n",
      "params_to_update_1に格納： features.10.weight\n",
      "params_to_update_1に格納： features.10.bias\n",
      "params_to_update_1に格納： features.12.weight\n",
      "params_to_update_1に格納： features.12.bias\n",
      "params_to_update_1に格納： features.14.weight\n",
      "params_to_update_1に格納： features.14.bias\n",
      "params_to_update_1に格納： features.17.weight\n",
      "params_to_update_1に格納： features.17.bias\n",
      "params_to_update_1に格納： features.19.weight\n",
      "params_to_update_1に格納： features.19.bias\n",
      "params_to_update_1に格納： features.21.weight\n",
      "params_to_update_1に格納： features.21.bias\n",
      "params_to_update_1に格納： features.24.weight\n",
      "params_to_update_1に格納： features.24.bias\n",
      "params_to_update_1に格納： features.26.weight\n",
      "params_to_update_1に格納： features.26.bias\n",
      "params_to_update_1に格納： features.28.weight\n",
      "params_to_update_1に格納： features.28.bias\n",
      "params_to_update_2に格納： classifier.0.weight\n",
      "params_to_update_2に格納： classifier.0.bias\n",
      "params_to_update_2に格納： classifier.3.weight\n",
      "params_to_update_2に格納： classifier.3.bias\n",
      "params_to_update_3に格納： classifier.6.weight\n",
      "params_to_update_3に格納： classifier.6.bias\n"
     ]
    }
   ],
   "source": [
    "# ファインチューニングで学習させるパラメータを、変数params_to_updateの1～3に格納する\n",
    "\n",
    "params_to_update_1 = []\n",
    "params_to_update_2 = []\n",
    "params_to_update_3 = []\n",
    "\n",
    "# 学習させる層のパラメータ名を指定\n",
    "update_param_names_1 = [\"features\"]\n",
    "update_param_names_2 = [\"classifier.0.weight\",\n",
    "                        \"classifier.0.bias\", \"classifier.3.weight\", \"classifier.3.bias\"]\n",
    "update_param_names_3 = [\"classifier.6.weight\", \"classifier.6.bias\"]\n",
    "\n",
    "# パラメータごとに各リストに格納する\n",
    "for name, param in net.named_parameters():\n",
    "    if update_param_names_1[0] in name:\n",
    "        param.requires_grad = True\n",
    "        params_to_update_1.append(param)\n",
    "        print(\"params_to_update_1に格納：\", name)\n",
    "\n",
    "    elif name in update_param_names_2:\n",
    "        param.requires_grad = True\n",
    "        params_to_update_2.append(param)\n",
    "        print(\"params_to_update_2に格納：\", name)\n",
    "\n",
    "    elif name in update_param_names_3:\n",
    "        param.requires_grad = True\n",
    "        params_to_update_3.append(param)\n",
    "        print(\"params_to_update_3に格納：\", name)\n",
    "\n",
    "    else:\n",
    "        param.requires_grad = False\n",
    "        print(\"勾配計算なし。学習しない：\", name)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 最適化手法の設定\n",
    "optimizer = optim.SGD([\n",
    "    {'params': params_to_update_1, 'lr': 1e-4},\n",
    "    {'params': params_to_update_2, 'lr': 5e-4},\n",
    "    {'params': params_to_update_3, 'lr': 1e-3}\n",
    "], momentum=0.9)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 学習・検証を実施"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# モデルを学習させる関数を作成\n",
    "\n",
    "\n",
    "def train_model(net, dataloaders_dict, criterion, optimizer, num_epochs):\n",
    "\n",
    "    # 初期設定\n",
    "    # GPUが使えるかを確認\n",
    "    device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "    print(\"使用デバイス：\", device)\n",
    "\n",
    "    # ネットワークをGPUへ\n",
    "    net.to(device)\n",
    "\n",
    "    # ネットワークがある程度固定であれば、高速化させる\n",
    "    torch.backends.cudnn.benchmark = True\n",
    "\n",
    "    # epochのループ\n",
    "    for epoch in range(num_epochs):\n",
    "        print('Epoch {}/{}'.format(epoch+1, num_epochs))\n",
    "        print('-------------')\n",
    "\n",
    "        # epochごとの訓練と検証のループ\n",
    "        for phase in ['train', 'val']:\n",
    "            if phase == 'train':\n",
    "                net.train()  # モデルを訓練モードに\n",
    "            else:\n",
    "                net.eval()   # モデルを検証モードに\n",
    "\n",
    "            epoch_loss = 0.0  # epochの損失和\n",
    "            epoch_corrects = 0  # epochの正解数\n",
    "\n",
    "            # 未学習時の検証性能を確かめるため、epoch=0の訓練は省略\n",
    "            if (epoch == 0) and (phase == 'train'):\n",
    "                continue\n",
    "\n",
    "            # データローダーからミニバッチを取り出すループ\n",
    "            for inputs, labels in tqdm(dataloaders_dict[phase]):\n",
    "\n",
    "                # GPUが使えるならGPUにデータを送る\n",
    "                inputs = inputs.to(device)\n",
    "                labels = labels.to(device)\n",
    "\n",
    "                # optimizerを初期化\n",
    "                optimizer.zero_grad()\n",
    "\n",
    "                # 順伝搬（forward）計算\n",
    "                with torch.set_grad_enabled(phase == 'train'):\n",
    "                    outputs = net(inputs)\n",
    "                    loss = criterion(outputs, labels)  # 損失を計算\n",
    "                    _, preds = torch.max(outputs, 1)  # ラベルを予測\n",
    "\n",
    "                    # 訓練時はバックプロパゲーション\n",
    "                    if phase == 'train':\n",
    "                        loss.backward()\n",
    "                        optimizer.step()\n",
    "\n",
    "                    # 結果の計算\n",
    "                    epoch_loss += loss.item() * inputs.size(0)  # lossの合計を更新\n",
    "                    # 正解数の合計を更新\n",
    "                    epoch_corrects += torch.sum(preds == labels.data)\n",
    "\n",
    "            # epochごとのlossと正解率を表示\n",
    "            epoch_loss = epoch_loss / len(dataloaders_dict[phase].dataset)\n",
    "            epoch_acc = epoch_corrects.double(\n",
    "            ) / len(dataloaders_dict[phase].dataset)\n",
    "\n",
    "            print('{} Loss: {:.4f} Acc: {:.4f}'.format(\n",
    "                phase, epoch_loss, epoch_acc))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "使用デバイス： cuda:0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 0/5 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/2\n",
      "-------------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [00:20<00:00,  5.13s/it]\n",
      "  0%|          | 0/8 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "val Loss: 0.7704 Acc: 0.4444\n",
      "Epoch 2/2\n",
      "-------------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 8/8 [00:25<00:00,  4.41s/it]\n",
      "  0%|          | 0/5 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.5257 Acc: 0.7243\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [00:02<00:00,  1.88it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "val Loss: 0.1765 Acc: 0.9542\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# 学習・検証を実行する\n",
    "num_epochs=2\n",
    "train_model(net, dataloaders_dict, criterion, optimizer, num_epochs=num_epochs)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 学習したネットワークを保存・ロード"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# PyTorchのネットワークパラメータの保存\n",
    "save_path = './weights_fine_tuning.pth'\n",
    "torch.save(net.state_dict(), save_path)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# PyTorchのネットワークパラメータのロード\n",
    "load_path = './weights_fine_tuning.pth'\n",
    "load_weights = torch.load(load_path)\n",
    "net.load_state_dict(load_weights)\n",
    "\n",
    "# GPU上で保存された重みをCPU上でロードする場合\n",
    "load_weights = torch.load(load_path, map_location={'cuda:0': 'cpu'})\n",
    "net.load_state_dict(load_weights)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "以上"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
